import os
import json
import inspect
import argparse
import pandas as pd
from tqdm.auto import tqdm
from sklearn.metrics import v_measure_score
from sklearn.model_selection import train_test_split
from sklearn.cluster import AgglomerativeClustering
from sklearn.ensemble import RandomForestClassifier

from _models.model import get_embedding_func_batched
from _datasets.data import DatasetConfig
from utils.transform_utils import *
from utils.string_utils import *
from utils.metrics import *


class ClusteringExperimentConfig:
    """Clustering experiment over one dataset with one embedding model.

    The pipeline (see ``run``) embeds the ``"original"`` column of the
    dataset, builds five pairwise similarity matrices (Levenshtein, ROUGE,
    BM25, Jaccard, cosine), scores each one by agglomerative clustering
    against the ``"labels"`` column, scores a random-forest ensemble built on
    all five, and writes the data and the scores under
    ``data/<model_name with '/' replaced by '_'>/``.
    """

    def __init__(
        self,
        dataset_name: str,
        num_examples: int,
        model_name: str = "BAAI/bge-small-en-v1.5",
        max_length: int = 8192,
    ):
        """Load the dataset, resolve the embedding function, create the output dir.

        Args:
            dataset_name: Name passed to ``DatasetConfig``.
            num_examples: Example count passed to ``DatasetConfig``.
            model_name: Embedding model identifier; also used to name the
                output directory.
            max_length: Passed as second argument to
                ``DatasetConfig.get_dataset``.
        """
        self.model_name = model_name
        self.dataset_name = dataset_name
        self.dataset_config = DatasetConfig(dataset_name, num_examples)
        # The meaning of the leading ``True`` is defined by DatasetConfig.
        self.dataset = self.dataset_config.get_dataset(True, max_length)
        print(f"Dataset {dataset_name} loaded.")

        self.embedding_func = get_embedding_func_batched(model_name)
        self.similarity_data = pd.DataFrame(self.dataset)
        self.results = {}
        # NOTE(review): ``metrics`` is not defined in this file; presumably it
        # comes from one of the star imports. It is overwritten with the list
        # of metric names in ``calculate_similarities``.
        self.metrics = metrics

        # Create directory for model data if it doesn't exist.
        # '/' is replaced with '_' so the model name maps to a single directory.
        self.model_data_path = os.path.join(
            "data", self.model_name.replace("/", "_")
        )
        os.makedirs(self.model_data_path, exist_ok=True)

    def run(self):
        """Run every pipeline stage in order, then persist data and scores.

        Writes ``<dataset>_clustering.pkl`` (the pickled ``similarity_data``
        frame) and ``<dataset>_clustering.json`` (the scores as floats) into
        ``self.model_data_path``.
        """
        self.generate_embeddings(
            embedding_func=self.embedding_func,
            model_name=self.model_name,
            use_gpu=True,
        )
        print("Generated embeddings.")

        self.calculate_similarities()
        print("Calculated similarities.")

        self.fit_ensembling()
        print("Fitted ensembling.")

        self.get_results()
        print("Got results.")

        # Save the similarity data as a pickle in the model-specific directory
        data_file_path = (
            f"{self.model_data_path}/{self.dataset_config.name}_clustering.pkl"
        )
        self.similarity_data.to_pickle(data_file_path)
        print(f"Saved data to {data_file_path}.")

        # Save the results to a JSON file in the model-specific directory.
        # Scores are converted to built-in floats first (json cannot encode
        # NumPy scalars), and before the file is opened so that a failed
        # conversion does not leave an empty file behind.
        results_file_path = (
            f"{self.model_data_path}/{self.dataset_config.name}_clustering.json"
        )
        self.results = {k: float(v) for k, v in self.results.items()}
        with open(results_file_path, "w") as f:
            json.dump(self.results, f)
        print(f"Saved results to {results_file_path}.")

    def generate_embeddings(self, embedding_func, **kwargs):
        """Embed the ``"original"`` texts into the ``"embeddings_original"`` column.

        Args:
            embedding_func: Callable invoked as
                ``embedding_func(prompts=..., pbar=False, **kwargs)``; it may
                return a list or an object with ``.tolist()``.
            **kwargs: Extra keyword arguments for ``embedding_func``. When the
                function's source does not mention ``"huggingface"``,
                ``model_name`` is renamed to ``model`` and ``use_gpu`` is
                dropped before the call.
        """
        # For models that are not from huggingface: detected by inspecting the
        # function's source text, so the callable must have retrievable source.
        source_code = inspect.getsource(embedding_func)
        if "huggingface" not in source_code:
            if "model_name" in kwargs:
                kwargs["model"] = kwargs.pop("model_name")
            # pop() rather than del: a caller that never passed use_gpu should
            # not fail just because there is nothing to discard.
            kwargs.pop("use_gpu", None)

        embeddings_column = "embeddings_original"
        # NOTE: dropna() shortens the prompt list when texts are missing, in
        # which case the column assignment below fails on a length mismatch.
        embeds = embedding_func(
            prompts=self.similarity_data["original"].dropna().tolist(),
            pbar=False,
            **kwargs,
        )
        self.similarity_data[embeddings_column] = (
            embeds if isinstance(embeds, list) else embeds.tolist()
        )

    def get_sim_matrix(self, sim_fn):
        """Return the ``(n, n)`` matrix of ``sim_fn`` over every ordered pair of rows.

        ``sim_fn`` receives the embeddings when it is ``cosine_similarity`` and
        the raw ``"original"`` texts otherwise. BM25 scores are min-max scaled
        to ``[0, 1]``; the other functions' values are stored unchanged.

        Args:
            sim_fn: Two-argument callable returning a number; its ``__name__``
                labels the progress bar.
        """
        n = len(self.similarity_data)
        if sim_fn == cosine_similarity:
            items = self.similarity_data["embeddings_original"].tolist()
        else:
            items = self.similarity_data["original"].tolist()

        matrix = np.zeros((n, n))
        # Every ordered pair is evaluated, so asymmetric functions are kept as-is.
        for i in tqdm(range(n), desc=sim_fn.__name__):
            for j in range(n):
                matrix[i, j] = sim_fn(items[i], items[j])

        # Min-max normalize BM25 scores.
        if sim_fn == bm25_score and matrix.size:
            shifted = matrix - matrix.min()
            span = shifted.max()
            # When all scores are equal the range is zero; keep the all-zero
            # shifted matrix instead of dividing 0/0 into NaNs.
            matrix = shifted / span if span != 0 else shifted

        return matrix

    def calculate_similarities(self):
        """Compute the five similarity matrices and record them with their names.

        Sets ``self.distances`` (the matrices) and ``self.metrics`` (the
        matching names, in the same order).
        """
        print("Calculating similarity matrices...")
        self.levenshtein_matrix = self.get_sim_matrix(levenshtein_ratio)
        print("\tCalculated levenshtein similarity")
        self.rouge_matrix = self.get_sim_matrix(rouge_score)
        print("\tCalculated rouge similarity")
        self.bm25_matrix = self.get_sim_matrix(bm25_score)
        print("\tCalculated bm25 similarity")
        self.jaccard_matrix = self.get_sim_matrix(jaccard_similarity)
        print("\tCalculated jaccard similarity")
        self.cosine_matrix = self.get_sim_matrix(cosine_similarity)
        print("\tCalculated cosine similarity")

        self.distances = [
            self.levenshtein_matrix,
            self.rouge_matrix,
            self.bm25_matrix,
            self.jaccard_matrix,
            self.cosine_matrix,
        ]
        self.metrics = ["levenshtein", "rouge", "bm25", "jaccard", "cosine"]

    def fit_ensembling(self, n_iterations: int = 1000):
        """Score a random-forest ensemble built on all similarity matrices.

        For each seed in ``range(n_iterations)`` a fresh
        ``RandomForestClassifier`` is fitted on an 80/20 split and its test
        predictions are scored with V-measure. The mean score is stored in
        ``self.results["ensembled_similarity"]``; ``self.ensemble`` keeps the
        last fitted classifier.

        Args:
            n_iterations: Number of seeded repetitions to average over.
        """
        # Stacking the (n, n) matrices gives (5n, n); transposed, row j holds
        # column j of every matrix, i.e. one feature vector per example.
        X = np.concatenate(self.distances).T
        y = self.similarity_data["labels"]

        scores = []
        for seed in tqdm(range(n_iterations), desc="Ensembling"):
            self.ensemble = RandomForestClassifier(random_state=seed)
            X_train, X_test, y_train, y_test = train_test_split(
                X, y, test_size=0.2, random_state=seed
            )
            self.ensemble.fit(X_train, y_train)
            # (labels_true, labels_pred) order, as in get_results; the metric
            # is symmetric so the value does not depend on the order.
            scores.append(v_measure_score(y_test, self.ensemble.predict(X_test)))

        self.results["ensembled_similarity"] = np.mean(scores)

    def get_results(self):
        """Cluster on each similarity matrix and score it against the labels.

        Each matrix is turned into a distance as ``1 - matrix`` and fed to
        complete-linkage agglomerative clustering with as many clusters as
        there are distinct labels. The V-measure is stored in ``self.results``
        under ``"<metric>_similarity"``.

        Returns:
            The ``self.results`` dict, including any entries added earlier.
        """
        labels = self.similarity_data["labels"]
        num_labels = len(np.unique(labels))
        clustering = AgglomerativeClustering(
            num_labels, metric="precomputed", linkage="complete"
        )

        for name, matrix in zip(self.metrics, self.distances):
            clustering.fit(1 - matrix)
            self.results[f"{name}_similarity"] = v_measure_score(
                labels, clustering.labels_
            )

        return self.results


def main(
    dataset_name="biorxiv-clustering-p2p",
    num_examples=5,
    model_name="embed-english-v3.0",
    max_length=8192,
):
    """Build a ``ClusteringExperimentConfig`` from the arguments and run it.

    All four arguments are forwarded unchanged to the experiment constructor;
    ``run()`` is then invoked on the resulting object.
    """
    experiment = ClusteringExperimentConfig(
        dataset_name=dataset_name,
        num_examples=num_examples,
        model_name=model_name,
        max_length=max_length,
    )
    experiment.run()


if __name__ == "__main__":
    # Command-line entry point: each flag mirrors one parameter of main()
    # and carries the same default value.
    cli = argparse.ArgumentParser()
    for flag, value_type, fallback in (
        ("--dataset_name", str, "biorxiv-clustering-p2p"),
        ("--num_examples", int, 5),
        ("--model_name", str, "embed-english-v3.0"),
        ("--max_length", int, 8192),
    ):
        cli.add_argument(flag, type=value_type, default=fallback)
    options = cli.parse_args()

    main(
        options.dataset_name,
        options.num_examples,
        options.model_name,
        options.max_length,
    )
